﻿#pragma kernel Initialize
#pragma kernel Project
#pragma kernel Apply

#include "SolverParameters.cginc"
#include "ContactHandling.cginc"
#include "ColliderDefinitions.cginc"
#include "CollisionMaterial.cginc"
#include "Simplex.cginc"
#include "AtomicDeltas.cginc"


// Maps each working thread of the Apply kernel to the particle it updates.
StructuredBuffer<int> particleIndices;

// Read-only per-particle inputs. `simplices` is addressed as simplices[simplexStart + j] and yields the
// particle index that is then used to index every other per-particle buffer below.
StructuredBuffer<int> simplices;
StructuredBuffer<float> invMasses;
StructuredBuffer<float> invRotationalMasses;
StructuredBuffer<float4> principalRadii;       // only .xyz is read (passed to EllipsoidRadius)
StructuredBuffer<float4> velocities;           // read by Initialize
StructuredBuffer<float4> prevPositions;        // read by Initialize
StructuredBuffer<float4> fluidInterface;       // only .w is read (accumulated as "miscibility" in Project)
StructuredBuffer<quaternion> prevOrientations; // read by Initialize
StructuredBuffer<quaternion> orientations;     // read by Project

// Read/write per-particle state.
RWStructuredBuffer<float4> positions;  // read by Project, handed to ApplyPositionDelta in Apply
RWStructuredBuffer<float4> deltas;     // NOTE(review): not referenced directly by the kernels in this file — presumably consumed by an include; confirm
RWStructuredBuffer<float4> userData;   // read by Project, handed to ApplyUserDataDelta in Apply

// Vulkan workaround: don't declare a RW array after a counter/append one (particleContacts) since the counter overlaps the first entry in the next array.
// Keep effectiveMasses declared BEFORE particleContacts for that reason.
RWStructuredBuffer<contactMasses> effectiveMasses; // written by Initialize, read by Project
RWStructuredBuffer<contact> particleContacts;
StructuredBuffer<uint> dispatchBuffer;             // element [3] is used as the upper bound on the contact index

// Variables set from the CPU
uint particleCount;   // number of threads that do work in Apply
float substepTime;    // passed to SolveAdhesion; also scales maxDepenetration and the diffusion speed in Project
float sorFactor;      // forwarded to ApplyPositionDelta

// Initialize kernel: one thread per contact.
// For the contact handled by this thread it:
//   1. blends per-particle data of both simplices using the contact's barycentric weights (pointA / pointB),
//   2. refreshes the contact distance from the previous-step positions,
//   3. rebuilds the contact tangent from the relative simplex velocity,
//   4. stores the per-axis inverse masses of both bodies in effectiveMasses.
[numthreads(128, 1, 1)]
void Initialize (uint3 id : SV_DispatchThreadID)
{
    unsigned int contactIndex = id.x;

    // Threads at or beyond the count stored in dispatchBuffer[3] have no contact to process.
    if (contactIndex >= dispatchBuffer[3])
        return;

    int sizeA;
    int sizeB;
    int startA = GetSimplexStartAndSize(particleContacts[contactIndex].bodyA, sizeA);
    int startB = GetSimplexStartAndSize(particleContacts[contactIndex].bodyB, sizeB);

    // Nothing writes the normal before the basis update below, so a local copy is used until then.
    float4 normal = particleContacts[contactIndex].normal;

    // Barycentric accumulators for simplex A.
    float4 velA = float4(0,0,0,0);
    float4 prevPosA = float4(0,0,0,0);
    quaternion prevRotA = quaternion(0, 0, 0, 0);
    float radiusA = 0;
    float invMassA = 0;
    float invRotMassA = 0;

    // Barycentric accumulators for simplex B.
    float4 velB = float4(0,0,0,0);
    float4 prevPosB = float4(0,0,0,0);
    quaternion prevRotB = quaternion(0, 0, 0, 0);
    float radiusB = 0;
    float invMassB = 0;
    float invRotMassB = 0;

    int k = 0;
    for (k = 0; k < sizeA; ++k)
    {
        int p = simplices[startA + k];
        float w = particleContacts[contactIndex].pointA[k];

        velA += velocities[p] * w;
        prevPosA += prevPositions[p] * w;
        prevRotA += prevOrientations[p] * w;
        invMassA += invMasses[p] * w;
        invRotMassA += invRotationalMasses[p] * w;
        radiusA += EllipsoidRadius(normal, prevOrientations[p], principalRadii[p].xyz) * w;
    }

    for (k = 0; k < sizeB; ++k)
    {
        int p = simplices[startB + k];
        float w = particleContacts[contactIndex].pointB[k];

        velB += velocities[p] * w;
        prevPosB += prevPositions[p] * w;
        prevRotB += prevOrientations[p] * w;
        invMassB += invMasses[p] * w;
        invRotMassB += invRotationalMasses[p] * w;
        radiusB += EllipsoidRadius(normal, prevOrientations[p], principalRadii[p].xyz) * w;
    }

    // Clear the fourth component so it does not leak into the point/distance math.
    prevPosA.w = 0;
    prevPosB.w = 0;

    // Contact distance: offset each simplex point by its radius along the normal
    // (A against the normal, B along it) and measure the gap along the normal.
    float4 contactPointA = prevPosA - normal * radiusA;
    float4 contactPointB = prevPosB + normal * radiusB;
    particleContacts[contactIndex].dist = dot(contactPointA - contactPointB, normal);

    // Contact basis: tangent is derived from the relative velocity of the two simplices.
    CalculateBasis(velA - velB,
                   particleContacts[contactIndex].normal,
                   particleContacts[contactIndex].tangent);

    // Rolling contacts are enabled when either body's material (taken from the first
    // particle of its simplex) requests them; a negative material index means "no material".
    int materialA = collisionMaterialIndices[simplices[startA]];
    int materialB = collisionMaterialIndices[simplices[startB]];

    bool rollsA = false;
    if (materialA >= 0)
        rollsA = collisionMaterials[materialA].rollingContacts > 0;

    bool rollsB = false;
    if (materialB >= 0)
        rollsB = collisionMaterials[materialB].rollingContacts > 0;

    bool rollingContacts = rollsA || rollsB;

    // FLOAT4_EPSILON keeps the reciprocal finite when the inertia tensor is zero.
    float4 invInertiaTensorA = 1.0/(GetParticleInertiaTensor(radiusA, invRotMassA) + FLOAT4_EPSILON);
    float4 invInertiaTensorB = 1.0/(GetParticleInertiaTensor(radiusB, invRotMassB) + FLOAT4_EPSILON);

    // Must be evaluated after CalculateBasis, since it reads the updated contact.
    float4 bitangent = GetBitangent(particleContacts[contactIndex]);

    // Contact masses for body A.
    CalculateContactMassesA(invMassA, invInertiaTensorA, prevPosA, prevRotA, contactPointA, rollingContacts,
                            particleContacts[contactIndex].normal,
                            particleContacts[contactIndex].tangent,
                            bitangent,
                            effectiveMasses[contactIndex].normalInvMassA,
                            effectiveMasses[contactIndex].tangentInvMassA,
                            effectiveMasses[contactIndex].bitangentInvMassA);

    // Contact masses for body B.
    CalculateContactMassesB(invMassB, invInertiaTensorB, prevPosB, prevRotB, contactPointB, rollingContacts,
                            particleContacts[contactIndex].normal,
                            particleContacts[contactIndex].tangent,
                            bitangent,
                            effectiveMasses[contactIndex].normalInvMassB,
                            effectiveMasses[contactIndex].tangentInvMassB,
                            effectiveMasses[contactIndex].bitangentInvMassB);
}

// Project kernel: one thread per contact.
// Blends the current state of both simplices with the contact's barycentric weights, solves
// adhesion + penetration along the contact normal, and accumulates the resulting position deltas
// (and, for close contacts, user-data diffusion deltas) atomically on every particle involved.
[numthreads(128, 1, 1)]
void Project (uint3 id : SV_DispatchThreadID) 
{
    unsigned int i = id.x;

    // Threads at or beyond the count stored in dispatchBuffer[3] have no contact to process.
    if (i >= dispatchBuffer[3]) return;

    int simplexSizeA;
    int simplexSizeB;
    int simplexStartA = GetSimplexStartAndSize(particleContacts[i].bodyA, simplexSizeA);
    int simplexStartB = GetSimplexStartAndSize(particleContacts[i].bodyB, simplexSizeB);

    // Combine collision materials (use material from first particle in simplex)
    collisionMaterial material = CombineCollisionMaterials(collisionMaterialIndices[simplices[simplexStartA]], collisionMaterialIndices[simplices[simplexStartB]]);

    float4 simplexPositionA = FLOAT4_ZERO, simplexPositionB = FLOAT4_ZERO;
    float simplexRadiusA = 0, simplexRadiusB = 0;
    float4 simplexUserDataA = FLOAT4_ZERO, simplexUserDataB = FLOAT4_ZERO;
    float miscibility = 0;

    // Barycentric blend of simplex A (weights: pointA).
    int j = 0;
    for (j = 0; j < simplexSizeA; ++j)
    {
        int particleIndex = simplices[simplexStartA + j];
        simplexPositionA += positions[particleIndex] * particleContacts[i].pointA[j];
        simplexRadiusA += EllipsoidRadius(particleContacts[i].normal, orientations[particleIndex], principalRadii[particleIndex].xyz) * particleContacts[i].pointA[j];
        simplexUserDataA += userData[particleIndex] * particleContacts[i].pointA[j];
        miscibility += fluidInterface[particleIndex].w * particleContacts[i].pointA[j];
    }

    // Barycentric blend of simplex B (weights: pointB).
    for (j = 0; j < simplexSizeB; ++j)
    {
        int particleIndex = simplices[simplexStartB + j];
        simplexPositionB += positions[particleIndex] * particleContacts[i].pointB[j];
        // Fixed: the radius of simplex B must be weighted by B's own barycentric coordinates (pointB),
        // exactly like every other B accumulator here and like the Initialize kernel does.
        simplexRadiusB += EllipsoidRadius(particleContacts[i].normal, orientations[particleIndex], principalRadii[particleIndex].xyz) * particleContacts[i].pointB[j];
        simplexUserDataB += userData[particleIndex] * particleContacts[i].pointB[j];
        miscibility += fluidInterface[particleIndex].w * particleContacts[i].pointB[j];
    }

    // Clear the fourth component so it does not leak into the point math.
    simplexPositionA.w = 0;
    simplexPositionB.w = 0;

    // Points on each simplex, offset by its radius along the normal (A against it, B along it).
    float4 posA = simplexPositionA - particleContacts[i].normal * simplexRadiusA;
    float4 posB = simplexPositionB + particleContacts[i].normal * simplexRadiusB;

    // Combined inverse mass along the normal, as precomputed by Initialize.
    float normalInvMass = effectiveMasses[i].normalInvMassA + effectiveMasses[i].normalInvMassB;

    // adhesion:
    float lambda = SolveAdhesion(particleContacts[i], normalInvMass, posA, posB, material.stickDistance, material.stickiness, substepTime);

    // depenetration, limited to maxDepenetration * substepTime:
    lambda += SolvePenetration(particleContacts[i], normalInvMass, posA, posB, maxDepenetration * substepTime);

    // Skip the atomics entirely when the correction is negligible.
    if (abs(lambda) > EPSILON)
    {
        // Shock propagation biases the correction between A (1 - shock) and B (1 + shock)
        // according to how aligned the contact normal is with gravity.
        float shock = shockPropagation * dot(particleContacts[i].normal.xyz, normalizesafe(gravity));
        float4 delta = lambda * particleContacts[i].normal;
               
        float baryScale = BaryScale(particleContacts[i].pointA);
        for (j = 0; j < simplexSizeA; ++j)
        {
            int particleIndex = simplices[simplexStartA + j];
            float4 delta1 = delta * invMasses[particleIndex] * particleContacts[i].pointA[j] * baryScale * (1 - shock);
            AtomicAddPositionDelta(particleIndex, delta1);
        }

        // B receives the opposite correction.
        baryScale = BaryScale(particleContacts[i].pointB);
        for (j = 0; j < simplexSizeB; ++j)
        {
            int particleIndex = simplices[simplexStartB + j];
            float4 delta2 = -delta * invMasses[particleIndex] * particleContacts[i].pointB[j] * baryScale * (1 + shock);
            AtomicAddPositionDelta(particleIndex, delta2);
        }
    }

    // property diffusion: only for contacts whose stored distance is below collisionMargin.
    if (particleContacts[i].dist < collisionMargin)
    {
        float diffusionSpeed = miscibility * 0.5 * substepTime;
        float4 userDelta = (simplexUserDataB - simplexUserDataA) * diffusionMask * diffusionSpeed;

        // NOTE(review): user-data deltas are accumulated through AtomicAddOrientationDelta and later
        // consumed by ApplyUserDataDelta in Apply — presumably the orientation delta storage doubles
        // as the user-data accumulator; confirm against AtomicDeltas.cginc.
        for (j = 0; j < simplexSizeA; ++j)
            AtomicAddOrientationDelta(simplices[simplexStartA + j], userDelta * particleContacts[i].pointA[j]);
           
        for (j = 0; j < simplexSizeB; ++j)
            AtomicAddOrientationDelta(simplices[simplexStartB + j], -userDelta * particleContacts[i].pointB[j]);
    }
}

// Apply kernel: one thread per entry of particleIndices.
// Hands the particle's position and user-data buffers to the delta-application helpers.
[numthreads(128, 1, 1)]
void Apply (uint3 id : SV_DispatchThreadID)
{
    // Only the first particleCount threads map to an entry of particleIndices.
    if (id.x < particleCount)
    {
        int particle = particleIndices[id.x];

        // sorFactor is forwarded to the position helper only.
        ApplyPositionDelta(positions, particle, sorFactor);
        ApplyUserDataDelta(userData, particle);
    }
}


